"""
SecMesh: A module to mesh the cross-section with triangular fibers

Author: Yexiang Yan
Email: yexiang_yan@outlook.com
WeChat: yx592715024
"""
import sys
import numpy as np
import gmsh
from shapely.geometry import Polygon, LineString
from shapely import affinity
import matplotlib.pyplot as plt
from matplotlib import collections

# Module-level shorthand for gmsh's OpenCASCADE geometry interface; the
# helpers and geometry classes below create and combine entities through it.
occ = gmsh.model.occ


def offset(points: list[list[float, float]], d: float):
    """Offset a closed outline inward or outward.

    The outline is converted to a shapely ``Polygon`` and buffered by ``-d``
    using ``cap_style=3`` and ``join_style=2`` (square caps and mitre joins
    in shapely's numbering).

    Args:
        points (list[list[float, float]]): Vertices of the closed outline.
        d (float): Offset distance. A positive value moves the outline
            inward, a negative value moves it outward.

    Returns:
        list[tuple[float, float]]: Coordinates of the exterior ring of the
        buffered polygon, exactly as shapely reports them.
    """
    outline = Polygon(points)
    shifted = outline.buffer(-d, cap_style=3, join_style=2)
    ring = shifted.exterior
    return list(ring.coords)


def rotate(points: list[list[float, float]], theta: float = 0, origin='center', use_radians=False):
    """Rotate a polyline using ``shapely.affinity.rotate``.

    Args:
        points (list[list[float, float]]): Vertices of the polyline.
        theta (float): Rotation angle, forwarded to shapely unchanged.
        origin: Rotation origin, forwarded to shapely unchanged.
        use_radians (bool): Forwarded to shapely; tells it whether ``theta``
            is given in radians.

    Returns:
        list[tuple[float, float]]: Coordinates of the rotated polyline.
    """
    rotated = affinity.rotate(
        LineString(points), theta, origin=origin, use_radians=use_radians
    )
    return list(rotated.coords)


def translate(points: list[list[float, float]], dx: float = 0.0, dy: float = 0.0):
    """Shift a polyline by ``dx`` and ``dy`` using ``shapely.affinity.translate``.

    Args:
        points (list[list[float, float]]): Vertices of the polyline.
        dx (float): Offset applied along x.
        dy (float): Offset applied along y.

    Returns:
        list[tuple[float, float]]: Coordinates of the shifted polyline.
    """
    moved = affinity.translate(LineString(points), xoff=dx, yoff=dy)
    return list(moved.coords)


def extract_mesh(dim, tag):
    """Read node coordinates and the cells of one entity from the gmsh model.

    Args:
        dim: Dimension of the gmsh entity whose elements are requested.
        tag: Tag of that entity.

    Returns:
        tuple: ``(points, cells)``. ``points`` is an ``(N, 3)`` array holding
        every node of the model ordered by node tag. ``cells`` holds the
        zero-based node indices of the entity's elements, one row per
        element, ordered by element tag. If the entity reports several
        element types only the last one is kept; if it reports none,
        ``cells`` is ``None``.
    """
    # All nodes of the model: tags plus a flat x, y, z coordinate list.
    node_ids, coords, _ = gmsh.model.mesh.getNodes()
    coords = np.asarray(coords).reshape(-1, 3)
    order = np.argsort(node_ids)
    node_ids -= 1  # switch from gmsh's 1-based tags to 0-based indices
    # The tags must form the contiguous range 1..N, otherwise "tag - 1"
    # could not be used directly as a row index into the coordinate array.
    assert np.all(node_ids[order] == np.arange(len(node_ids)))
    coords = coords[order]

    # Elements that belong to the requested entity.
    cells = None
    type_list, tag_lists, node_lists = gmsh.model.mesh.getElements(dim, tag)
    for etype, etags, enodes in zip(type_list, tag_lists, node_lists):
        # getElementProperties returns (elementName, dim, order, numNodes,
        # localNodeCoord, numPrimaryNodes); index 3 is the node count.
        nodes_per_cell = gmsh.model.mesh.getElementProperties(etype)[3]
        connectivity = np.asarray(enodes).reshape(-1, nodes_per_cell) - 1
        cells = connectivity[np.argsort(etags)]

    return coords, cells


def lines_subdivide(x, y, gap):
    """Subdivide a polyline so that consecutive points are at most ``gap`` apart.

    Every segment of the polyline is split into the smallest number of equal
    parts whose length does not exceed ``gap``. All original vertices are
    kept, and zero-length segments (repeated vertices) are skipped.

    Args:
        x: Sequence of x coordinates of the polyline vertices.
        y: Sequence of y coordinates of the polyline vertices.
        gap (float): Maximum allowed spacing between consecutive points.

    Returns:
        numpy.ndarray: Array of shape ``(M, 2)`` with the subdivided points,
        ending with the last vertex of the polyline. Empty input yields an
        array of shape ``(0, 2)``.

    Raises:
        ValueError: If ``gap`` is not positive.
    """
    if gap <= 0:
        raise ValueError("gap must be a positive number")
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size == 0:
        return np.empty((0, 2))

    x_new = []
    y_new = []
    for x1, y1, x2, y2 in zip(x[:-1], y[:-1], x[1:], y[1:]):
        length = np.hypot(x2 - x1, y2 - y1)
        if length == 0:
            continue  # repeated vertex, nothing to subdivide
        # Number of intervals, not points. The small tolerance stops
        # round-off (e.g. 10.000000001) from adding a spurious interval.
        n_seg = max(1, int(np.ceil(length / gap - 1e-9)))
        # n_seg intervals need n_seg + 1 points; the end point is dropped
        # here because it is emitted as the start of the next segment.
        x_new.extend(np.linspace(x1, x2, n_seg + 1)[:-1].tolist())
        y_new.extend(np.linspace(y1, y2, n_seg + 1)[:-1].tolist())
    x_new.append(float(x[-1]))
    y_new.append(float(y[-1]))
    return np.column_stack((x_new, y_new))


def points2curvesloop(points, mesh_size):
    """Create a closed gmsh curve loop from polygon vertices.

    The vertices are passed through a shapely ``Polygon`` and its exterior
    ring is turned into gmsh points joined by straight lines.

    Args:
        points: Vertices of the polygon as ``(x, y)`` pairs.
        mesh_size: Target mesh size assigned to every created point.

    Returns:
        The tag of the created curve loop.
    """
    ring = list(Polygon(points).exterior.coords)
    # The last ring coordinate is left out: the loop is closed below by
    # joining the final point back to the first one.
    corners = ring[:-1]
    point_tags = [
        occ.addPoint(px, py, 0, meshSize=mesh_size, tag=-1)
        for px, py in corners
    ]
    count = len(point_tags)
    # Connect each point to its successor; the modulo wraps the last point
    # back to the first.
    line_tags = [
        occ.addLine(point_tags[k], point_tags[(k + 1) % count], tag=-1)
        for k in range(count)
    ]
    return occ.addCurveLoop(line_tags)


def sec_rotation(x, y, theta):
    """Rotate coordinates about the origin by ``theta`` radians.

    The transform is ``x' = x*cos(theta) + y*sin(theta)`` and
    ``y' = -x*sin(theta) + y*cos(theta)``. For a positive ``theta`` this
    moves the points clockwise (equivalently, it expresses them in axes
    turned counterclockwise by ``theta``); e.g. ``(1, 0)`` with
    ``theta = pi/2`` becomes ``(0, -1)``.

    Args:
        x: x coordinate(s), scalar or array.
        y: y coordinate(s), scalar or array.
        theta: Rotation angle in radians.

    Returns:
        tuple: ``(x_new, y_new)`` transformed coordinates.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return x * cos_t + y * sin_t, y * cos_t - x * sin_t


class _geom_base:
    """Common base for the 2D geometry objects built on the gmsh OCC kernel.

    The boolean operations modify the object in place and return it, so
    calls can be chained; ``+`` and ``-`` are shorthand for ``union`` and
    ``cut``.
    """

    def __init__(self):
        self.dim = 2  # every object of this family is a surface
        self.obj_tag = None
        self.obj_dim_tags = None  # list of (dim, tag) pairs owned by the object
        self.obj_group_name = None  # filled in by SecMesh.assign_group
        self.obj_type = None

    def __add__(self, other):
        return self.union(other)

    def __sub__(self, other):
        return self.cut(other)

    def union(self, other):
        """Fuse ``other`` into this object and return ``self``."""
        self.obj_dim_tags, _ = occ.fuse(self.obj_dim_tags, other.obj_dim_tags)
        return self

    def cut(self, other):
        """Subtract ``other`` from this object and return ``self``.

        Both operands are passed with ``removeObject``/``removeTool`` set, so
        ``other`` should not be reused afterwards.
        """
        self.obj_dim_tags, _ = occ.cut(self.obj_dim_tags, other.obj_dim_tags,
                                       tag=-1, removeObject=True, removeTool=True)
        return self

    def intersect(self, other):
        """Keep only the part shared with ``other`` and return ``self``."""
        self.obj_dim_tags, _ = occ.intersect(self.obj_dim_tags, other.obj_dim_tags,
                                             tag=-1, removeObject=True, removeTool=True)
        return self

    def rotate(self, theta):
        """Rotate the object about the z axis through the origin.

        Args:
            theta: Rotation angle in degrees.

        Returns:
            ``self``, like the other operations, so that calls can be
            chained.
        """
        theta = theta / 180 * np.pi  # gmsh expects radians
        for dim, tag in self.obj_dim_tags:
            # Axis point (0, 0, 0), axis direction (0, 0, 1).
            occ.rotate([(dim, tag)], 0, 0, 0, 0, 0, 1, theta)
        return self


class _Polygon(_geom_base):
    """Plane surface bounded by a polygon, optionally with polygonal holes."""

    def __init__(self, points: list[list[float, float]], mesh_size,
                 holes: list[list[list[float, float]]] = None):
        """Build the surface in the gmsh OCC kernel.

        Args:
            points: Vertices of the outer boundary.
            mesh_size: Mesh size assigned to every boundary point, including
                those of the holes.
            holes: Optional list of vertex lists, one per hole.
        """
        super().__init__()
        self.obj_type = 'gmsh'

        # The outer boundary comes first, followed by one loop per hole.
        loops = [points2curvesloop(points, mesh_size)]
        if holes:
            loops.extend(points2curvesloop(ring, mesh_size) for ring in holes)
        surface = occ.addPlaneSurface(loops)
        self.obj_dim_tags = [(self.dim, surface)]


class _Circle(_geom_base):
    """Plane surface bounded by a circle, optionally with polygonal holes."""

    def __init__(self, xo: list[float, float], radius: float, mesh_size: float,
                 holes=None, angle1=0., angle2=360):
        """Build the surface in the gmsh OCC kernel.

        Args:
            xo: ``(x, y)`` coordinates of the circle centre.
            radius: Circle radius.
            mesh_size: Mesh size assigned to the points of the holes. It is
                not applied to the circle itself, because it is not passed
                to ``addCircle``.
            holes: Optional list of vertex lists, one polygon per hole.
            angle1: Start angle of the arc in degrees.
            angle2: End angle of the arc in degrees.
                NOTE(review): the single arc is used directly as the curve
                loop, which presumably only works for a full circle;
                confirm before using partial angles.
        """
        super().__init__()
        # Tag the object the same way _Polygon does.
        self.obj_type = 'gmsh'

        angle1 = angle1 / 180 * np.pi  # gmsh expects radians
        angle2 = angle2 / 180 * np.pi
        x, y, z = xo[0], xo[1], 0
        arc = occ.addCircle(x, y, z, radius, tag=-1, angle1=angle1, angle2=angle2)
        # Use the camelCase name like the rest of the file.
        exterior_tag = occ.addCurveLoop([arc])

        # holes
        holes_tag = []
        if holes:
            for hole_points in holes:
                holes_tag.append(points2curvesloop(hole_points, mesh_size))
        self.obj_dim_tags = [(self.dim, occ.addPlaneSurface([exterior_tag] + holes_tag))]


class SecMesh:
    def __init__(self, sec_name: str = 'My Section'):
        """Start a gmsh session and create an empty model for the section.

        Args:
            sec_name: Name of the gmsh model; also used as the plot title.
        """
        super().__init__()
        gmsh.initialize()
        self.sec_name = sec_name
        gmsh.model.add(sec_name)

        # Per-group bookkeeping, all keyed by group name.
        self.group_tag_map = {}  # surface tags of each group
        self.mat_tag_map = {}  # material tag of each group
        self.color_map = {}  # plot colour of each group
        self.colors_default = [
            '#037ef3', '#f85a40', '#00c16e', '#7552cc',
            '#0cb9c1', '#f48924', '#f48924', '#52565e',
        ]
        self.cells_map = {}  # triangle connectivity of each group
        self.points = None  # mesh node coordinates, filled by mesh()
        self.centers_map = {}  # fibre centroids of each group
        self.areas_map = {}  # fibre areas of each group
        self.sec_property = {}  # result of the last sec_props() call

        # Whether the coordinates have already been moved to the centroid.
        self.is_centring = False

        # One dict per rebar line added with add_rebar_line().
        self.rebar_data = []

    def assign_group(self, group: dict):
        for name, obj in group.items():
            obj.obj_group_name = name
            tags = [dim_tag[1] for dim_tag in obj.obj_dim_tags]
            self.group_tag_map[name] = tags
        return self

    def assign_group_mat(self, mat: dict):
        """Assign a material tag to groups registered with ``assign_group``.

        Args:
            mat: Mapping from group name to material tag.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            ValueError: If no group has been registered yet, or if a name in
                ``mat`` is not a registered group. Nothing is stored in
                either case.
        """
        if not self.group_tag_map:
            raise ValueError("应先运行assign_group方法")
        # Check every name before storing anything.
        for name in mat:
            if name in self.group_tag_map:
                continue
            raise ValueError(f"{name}没有在assign_group函数中指定")
        self.mat_tag_map.update(mat)
        return self

    def assign_group_color(self, color: dict):
        """Assign a plot colour to groups registered with ``assign_group``.

        Args:
            color: Mapping from group name to a colour.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            ValueError: If no group has been registered yet, or if a name in
                ``color`` is not a registered group. Nothing is stored in
                either case.
        """
        if not self.group_tag_map:
            raise ValueError("应先运行assign_group方法")
        # Check every name before storing anything.
        for name in color:
            if name in self.group_tag_map:
                continue
            raise ValueError(f"{name}没有在assign_group函数中指定")
        self.color_map.update(color)
        return self

    def _get_cell_data(self):
        """Collect the generated mesh and derive per-fibre data.

        For every registered group the triangles of all its surfaces are
        stacked into ``cells_map``; ``points`` receives the x/y coordinates
        of the mesh nodes. The centroid and area of every triangle are then
        stored in ``centers_map`` and ``areas_map``.

        Raises:
            RuntimeError: If no group was registered, or a group has no
                mesh elements.
            ValueError: If the cells of a group are not 3-node triangles.
        """
        if not self.group_tag_map:
            raise RuntimeError(
                "No groups defined: call assign_group() before mesh().")

        points = None
        for name, tags in self.group_tag_map.items():
            group_cells = []
            for tag in tags:
                # extract_mesh returns all model nodes on every call, so the
                # last returned array can be used for every group.
                points, cells = extract_mesh(2, tag)
                if cells is not None:
                    group_cells.append(cells)
            if not group_cells:
                raise RuntimeError(f"Group '{name}' contains no mesh elements.")
            self.cells_map[name] = np.vstack(group_cells)
        self.points = points[:, :2]

        # Centroid and area of every triangular fibre.
        vertices = self.points
        for name, faces in self.cells_map.items():
            if faces.shape[1] != 3:
                raise ValueError(
                    f"Group '{name}' has cells with {faces.shape[1]} nodes; "
                    "only 3-node triangles are supported.")
            p1 = vertices[faces[:, 0]]
            p2 = vertices[faces[:, 1]]
            p3 = vertices[faces[:, 2]]
            self.centers_map[name] = (p1 + p2 + p3) / 3
            # Half the absolute cross product of two edge vectors.
            e12 = p2 - p1
            e13 = p3 - p1
            self.areas_map[name] = 0.5 * np.abs(
                e12[:, 0] * e13[:, 1] - e13[:, 0] * e12[:, 1])

    def view(self, fill=True):
        """Plot the meshed section and the rebars with matplotlib.

        Args:
            fill: If True the triangles are drawn as filled patches in the
                group colour; otherwise only the triangle edges are drawn.

        Groups without a colour in ``color_map`` receive one from
        ``colors_default``; the palette is reused cyclically when there are
        more groups than default colours.
        """
        palette = self.colors_default
        for i, name in enumerate(self.group_tag_map):
            self.color_map.setdefault(name, palette[i % len(palette)])

        fig, ax = plt.subplots(figsize=(8, 8))
        vertices = self.points  # node coordinates indexed by the triangles
        for name, faces in self.cells_map.items():
            color = self.color_map[name]
            x = vertices[:, 0]
            y = vertices[:, 1]
            if not fill:
                ax.triplot(x, y, triangles=faces, color=color,
                           lw=1, zorder=-10)
                # Empty plot that only provides the legend entry.
                ax.plot([], [], '^', label=name, mec=color, mfc='white')
            else:
                ax.triplot(x, y, triangles=faces, lw=0.75, color='#516572')
                # "closed" is given by keyword: newer Matplotlib releases
                # no longer accept it positionally.
                patches = [plt.Polygon(vertices[face_link, :2], closed=True)
                           for face_link in faces]
                coll = collections.PatchCollection(patches, facecolors=color,
                                                   edgecolors='#516572', linewidths=0.75,
                                                   zorder=-10)
                ax.add_collection(coll)
                # Empty plot that only provides the legend entry.
                ax.plot([], [], '^', label=name, color=color)

        # Rebars are drawn as circles with their real diameter.
        for data in self.rebar_data:
            radius = data['dia'] / 2
            patches = [plt.Circle((xy[0], xy[1]), radius)
                       for xy in data['rebar_xy']]
            coll = collections.PatchCollection(patches, facecolors=data['color'])
            ax.add_collection(coll)

        ax.set_aspect('equal')
        ax.set_title(self.sec_name, fontsize=26, fontfamily='SimSun')
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        # At least one legend column, even when no group is registered.
        ax.legend(fontsize=18, shadow=True, markerscale=3, loc=10,
                  ncol=max(1, len(self.group_tag_map)),
                  bbox_to_anchor=(0.5, -0.1), bbox_transform=ax.transAxes)
        ax.tick_params(labelsize=18)
        plt.show()

    def sec_props(self):
        if self.cells_map == dict():
            raise RuntimeError(f"计算截面特性前必须先运行mesh函数划分截面!")
        # 整体的几何性质
        fiber_centers = []
        fiber_areas = []
        for name in self.cells_map.keys():
            fiber_centers.append(self.centers_map[name])
            fiber_areas.append(self.areas_map[name])
        fiber_centers = np.vstack(fiber_centers)
        fiber_areas = np.hstack(fiber_areas)
        # 求总面积
        Area = np.sum(fiber_areas)
        # 求形心坐标
        yo = np.sum(fiber_centers[:, 0] * fiber_areas) / Area
        zo = np.sum(fiber_centers[:, 1] * fiber_areas) / Area
        # 求截面惯性矩
        Iy = np.sum(fiber_centers[:, 1] ** 2 * fiber_areas)
        Iz = np.sum(fiber_centers[:, 0] ** 2 * fiber_areas)
        # 回转半径
        ry = np.sqrt(Iy / Area)
        rz = np.sqrt(Iz / Area)
        # 极惯性矩
        Ip = Iy + Iz
        # 惯性积
        Iyz = np.sum(fiber_centers[:, 0] * fiber_centers[:, 1] * fiber_areas)
        # 配筋率  # 钢筋
        if self.rebar_data:
            all_rebar_area = 0
            for data in self.rebar_data:
                rebar_xy = data['rebar_xy']
                dia = data['dia']
                rebar_coords = []
                rebar_areas = []
                for xy in rebar_xy:
                    rebar_coords.append(xy)
                    rebar_areas.append(np.pi / 4 * dia ** 2)
                all_rebar_area += np.sum(rebar_areas)
            rho_rebar = all_rebar_area / Area
        else:
            rho_rebar = 0
        # 集总
        sec_property = dict(area=Area, center=(yo, zo), Iy=Iy,
                            Iz=Iz, ry=ry, rz=rz, Ip=Ip, Iyz=Iyz, rho_rebar=rho_rebar)
        self.sec_property = sec_property
        return sec_property, fiber_centers, fiber_areas

    def centring(self):
        # 将截面移动到形心位置
        if self.sec_property == dict():
            sec_property, fiber_centers, fiber_areas = self.sec_props()
        else:
            sec_property = self.sec_property
        center = np.array(sec_property['center'])
        self.points -= center
        names = self.centers_map.keys()
        for name in names:
            self.centers_map[name] -= center
        # 钢筋
        for i, data in enumerate(self.rebar_data):
            self.rebar_data[i]['rebar_xy'] -= center
        self.is_centring = True

    def rotate(self, theta=0):
        """Rotate the meshed section clockwise about its centroid.

        The section is first moved to its centroid if that has not happened
        yet. Mesh nodes, fibre centroids and rebar coordinates are then
        rotated in place, and the cached ``sec_property`` is recomputed so
        that it matches the rotated coordinates.

        Args:
            theta: Rotation angle in degrees.
        """
        theta = theta / 180 * np.pi  # convert to radians

        if not self.is_centring:
            self.centring()

        # Every coordinate array that has to follow the rotation.
        coord_sets = [self.points]
        coord_sets.extend(self.centers_map.values())
        coord_sets.extend(data['rebar_xy'] for data in self.rebar_data)
        for xy in coord_sets:
            # Both rotated columns are computed before either is written
            # back, so the in-place update cannot mix old and new values.
            x_rot, y_rot = sec_rotation(xy[:, 0], xy[:, 1], theta)
            xy[:, 0], xy[:, 1] = x_rot, y_rot

        # The moments of area depend on the orientation; refresh the cache.
        self.sec_props()

    def opspy_cmds(self, secTag: int, GJ: float):
        """Define the fibre section in the running OpenSeesPy model.

        One fibre is created for every mesh triangle, using the material tag
        of its group, and one for every rebar point.

        :param secTag: section tag
        :param GJ: torsional stiffness passed with the ``-GJ`` option
        :return: None
        """
        import openseespy.opensees as ops
        ops.section('Fiber', secTag, '-GJ', GJ)

        # Mesh fibres, group by group.
        for group in self.centers_map:
            group_centers = self.centers_map[group]
            group_areas = self.areas_map[group]
            mat = self.mat_tag_map[group]
            for xy, fibre_area in zip(group_centers, group_areas):
                ops.fiber(xy[0], xy[1], fibre_area, mat)

        # Rebars
        for bar_set in self.rebar_data:
            bar_points = bar_set['rebar_xy']
            diameter = bar_set['dia']
            mat = bar_set['matTag']
            for xy in bar_points:
                ops.fiber(xy[0], xy[1], np.pi / 4 * diameter ** 2, mat)

    def to_file(self, output_path: str, secTag: int, GJ: float):

        if not (output_path.endswith('.tcl') or output_path.endswith('.py')):
            raise ValueError(f"output_path 必须以tcl或py结尾!")

        names = self.centers_map.keys()
        with open(output_path, "w+") as output:
            output.write('# This document was created from SecMesh \n')
            output.write('# Author: Yexiang Yan  yanswjtu@yeah.net \n\n')
            if output_path.endswith('.tcl'):
                output.write(f'set SecTag {secTag}\n')
                temp = "{"
                output.write(f'section fiberSec $secTag -GJ {GJ}{temp};    # Define the fiber section\n')
                for name in names:
                    centers = self.centers_map[name]
                    areas = self.areas_map[name]
                    matTag = self.mat_tag_map[name]
                    for center, area in zip(centers, areas):
                        output.write(f'    fiber {center[0]:.3f} {center[1]:.3f} {area:.3f} {matTag}\n')
                # 钢筋
                for data in self.rebar_data:
                    output.write('    # Define Rebar\n')
                    rebar_xy = data['rebar_xy']
                    dia = data['dia']
                    matTag = data['matTag']
                    for xy in rebar_xy:
                        area = np.pi / 4 * dia ** 2
                        output.write(f'    fiber {xy[0]:.3f} {xy[1]:.3f} {area:.3f} {matTag}\n')
                output.write('};    # end of fibersection definition')
            elif output_path.endswith('.py'):
                output.write('import openseespy.opensees as ops\n\n\n')
                output.write(f"ops.section('Fiber', {secTag}, '-GJ', {GJ})  # Define the fiber section\n")
                for name in names:
                    centers = self.centers_map[name]
                    areas = self.areas_map[name]
                    matTag = self.mat_tag_map[name]
                    for center, area in zip(centers, areas):
                        output.write(f'ops.fiber({center[0]:.3f}, {center[1]:.3f}, {area:.3f}, {matTag})\n')
                # 钢筋
                for data in self.rebar_data:
                    output.write('# Define Rebar\n')
                    rebar_xy = data['rebar_xy']
                    dia = data['dia']
                    matTag = data['matTag']
                    for xy in rebar_xy:
                        area = np.pi / 4 * dia ** 2
                        output.write(f'ops.fiber({xy[0]:.3f}, {xy[1]:.3f}, {area:.3f}, {matTag})\n')

    @staticmethod
    def add_polygon(points: list[list[float, float]], mesh_size,
                    holes: list[list[list[float, float]]] = None):
        """Create a polygonal surface in the current gmsh model.

        Args:
            points: Vertices of the outer boundary.
            mesh_size: Mesh size assigned to the boundary points.
            holes: Optional list of vertex lists, one per hole.

        Returns:
            The created ``_Polygon`` geometry object.
        """
        polygon = _Polygon(points=points, mesh_size=mesh_size, holes=holes)
        return polygon

    @staticmethod
    def add_circle(xo: list[float, float], radius: float, mesh_size: float,
                   holes=None, angle1=0., angle2=360):
        """Create a circular surface in the current gmsh model.

        Args:
            xo: ``(x, y)`` coordinates of the circle centre.
            radius: Circle radius.
            mesh_size: Mesh size, forwarded to ``_Circle``.
            holes: Optional list of vertex lists, one polygon per hole.
            angle1: Start angle in degrees.
            angle2: End angle in degrees.

        Returns:
            The created ``_Circle`` geometry object.
        """
        circle = _Circle(xo=xo, radius=radius, mesh_size=mesh_size,
                         holes=holes, angle1=angle1, angle2=angle2)
        return circle

    def add_rebar_line(self, points, dia, gap,
                       color='black', group_name=None, matTag=None):
        """Place rebars along a polyline at a given spacing.

        The polyline is subdivided with ``lines_subdivide`` and one bar is
        placed at every resulting point. For a closed polyline (first and
        last vertex coincide, as in the outlines returned by ``offset``) the
        repeated end point is dropped so that no bar is created twice.

        Args:
            points: Vertices of the polyline the bars are placed on.
            dia: Bar diameter.
            gap: Spacing used to subdivide the polyline.
            color: Colour used when plotting the bars.
            group_name: Optional name stored with the bars.
            matTag: Material tag used when the fibres are exported.

        Raises:
            ValueError: If ``gap`` is not positive.
        """
        if gap <= 0:
            raise ValueError("gap must be a positive number")
        rebar_lines = LineString(points)
        x, y = rebar_lines.xy
        # Redistribute the points along the line according to the spacing.
        rebar_xy = lines_subdivide(x, y, gap)
        # A closed loop ends where it starts; keep only one bar there.
        if len(rebar_xy) > 2 and np.allclose(rebar_xy[0], rebar_xy[-1]):
            rebar_xy = rebar_xy[:-1]
        data = dict(rebar_xy=rebar_xy, color=color,
                    name=group_name, dia=dia,
                    matTag=matTag)
        self.rebar_data.append(data)

    def mesh(self):
        """Generate the 2D mesh, collect the fibre data and close gmsh.

        The gmsh session is finalized at the end, so no further geometry can
        be added to this section afterwards.
        """
        # Transfer the OCC geometry into the gmsh model before meshing.
        occ.synchronize()
        gmsh.model.mesh.generate(2)
        # The data must be read while the gmsh session is still open.
        self._get_cell_data()
        self.close()

    @staticmethod
    def close():
        """Finalize the gmsh session (called by ``mesh`` once the data is read)."""
        gmsh.finalize()


# Demo script: builds, meshes, plots and exports three example sections.
if __name__ == '__main__':

    # case 1: unit square with a 0.05 cover layer and two square holes
    sec = SecMesh(sec_name='桥墩底截面')
    outlines = [[0, 0], [0, 1], [1, 1], [1, 0]]
    # Inner edge of the cover layer, also the outer edge of the core.
    outlines2 = offset(outlines, d=0.05)
    holes1 = [[0.2, 0.2], [0.4, 0.2], [0.4, 0.4], [0.2, 0.4]]
    holes2 = [[0.6, 0.2], [0.8, 0.2], [0.8, 0.4], [0.6, 0.4]]
    # The cover is the outline with the core region cut out as a hole.
    cover = sec.add_polygon(outlines, mesh_size=0.1, holes=[outlines2])
    # cover.rotate(theta=45)
    core = sec.add_polygon(outlines2, mesh_size=0.1, holes=[holes1, holes2])
    sec.assign_group({"cover": cover, "core": core})
    sec.assign_group_color({"cover": "#34bf49", "core": "#0099e5"})
    sec.assign_group_mat({"cover": 1, "core": 2})
    # Rebar lines: one ring offset from the outline and one diagonal line.
    rebar_lines1 = offset(outlines, d=0.06 + 0.032 / 2)
    rebar_lines2 = [[0.2, 0.2], [0.8, 0.8]]
    sec.add_rebar_line(points=rebar_lines1, dia=0.032, gap=0.1, color='red', matTag=3)
    sec.add_rebar_line(points=rebar_lines2, dia=0.020, gap=0.1, color='black', matTag=3)
    sec.mesh()
    # A zero-degree rotation still moves the section to its centroid.
    sec.rotate(0)
    sec.view(fill=False)
    sec.sec_props()
    sec.to_file(output_path='yan.py', secTag=1, GJ=1000)

    # case 2: circular section with a cover ring and a square hole in the core
    sec = SecMesh(sec_name='圆形截面')
    geom1 = sec.add_circle(xo=[0, 0], radius=1.5, mesh_size=0.02, holes=None, angle1=0., angle2=360)
    geom2 = sec.add_circle(xo=[0, 0], radius=1.5 - 0.06, mesh_size=0.02, holes=None, angle1=0., angle2=360)
    # The cover ring is the outer disc minus the inner disc.
    cover = geom1 - geom2
    holes = [(-0.4, -0.4), (0.4, -0.4), (0.4, 0.4), (-0.4, 0.4)]
    core = sec.add_circle(xo=[0, 0], radius=1.5 - 0.06, mesh_size=0.06, holes=[holes])
    sec.assign_group({"cover": cover, "core": core}).assign_group_color({"cover": "#34bf49", "core": "#0099e5"})
    sec.mesh()
    sec.view()

    # case 3: H-shaped section with two holes and three rebar lines
    sec = SecMesh(sec_name='我的截面')
    outlines = [(-4, 4), (-2, 4), (-2, 2), (2, 2), (2, 4), (4, 4),
                (4, -4), (2, -4), (2, -2), (-2, -2), (-2, -4), (-4, -4)]
    holes = [[(-3, 1), (-2, 1), (-2, -1), (-3, -1)],
             [(3, 1), (2, 1), (2, -1), (3, -1)]]

    # Cover boundary and rebar lines are inward offsets of the outline.
    cover_lines = offset(outlines, d=0.1)
    rebar_lines1 = offset(outlines, d=0.1 + 0.08 / 2)
    rebar_lines2 = offset(outlines, d=0.5 + 0.08 / 2)
    rebar_lines3 = [[-3, -3], [3, 3]]
    cover = sec.add_polygon(outlines, mesh_size=0.3, holes=[cover_lines])
    core = sec.add_polygon(cover_lines, mesh_size=0.5, holes=holes)
    sec.assign_group({"cover": cover, "core": core})
    sec.assign_group_color({"cover": "#34bf49", "core": "#0099e5"})
    sec.assign_group_mat({"cover": 1, "core": 2})
    sec.add_rebar_line(points=rebar_lines1, dia=0.08, gap=0.3, color='red', matTag=3)
    sec.add_rebar_line(points=rebar_lines2, dia=0.08, gap=0.3, color='purple', matTag=3)
    sec.add_rebar_line(points=rebar_lines3, dia=0.08, gap=0.3, color='green', matTag=3)
    sec.mesh()
    sec.rotate(0)
    sec.view()
    sec.sec_props()
    sec.to_file(output_path='yan.py', secTag=1, GJ=1000)
